library(data.table)
library(ggplot2)
library(knitr)
library(ggrepel)
library(RColorBrewer)
library(DESeq2)
library(rnaseqGene)
library(plotly)
library(Rtsne)

setDTthreads(8)

# Directory holding the compiled screen data. here::here() builds paths from
# the project root; this one points two levels above it.
data.output.dir <- here::here(
  '..', '..',
  's3-roybal-tcsl',
  'lenti_screen_compiled_data', 'data')

# Set to TRUE to re-run DESeq2; otherwise the saved results are loaded from
# data.output.dir.
rerun_deseq <- FALSE

# ---- Run DESeq2 across all reference/test bin combinations ----

# Load the pooled analysis objects into the global environment (presumably
# including `read.counts`, which the DESeq2 loop reads — verify).
load(file.path(data.output.dir, "pooled_analysis_data.Rdata"))

# Cast per-CAR read counts for the requested sort bins into a wide table with
# one row per CAR.align and one column per `<sort.group>_<bin>` sample.
cast_bin_counts <- function(data.dt, bin_set, assay_input, k_type, t_type) {
  dcast(
    data.dt[
      bin %in% bin_set & assay == assay_input &
        k.type == k_type & t.type == t_type,
      .(
        CAR.align,
        bin.sort.group = paste(sort.group, bin, sep = '_'),
        counts)],
    CAR.align ~ bin.sort.group, value.var = 'counts')
}

# Run DESeq2 comparing CAR read counts in the `test_bin` sort bins against
# the `ref_bin` sort bins (or against baseline counts when
# ref_bin == 'baseline') for a single assay / k.type / t.type slice.
#
# Args:
#   data.dt: data.table with CAR.align, bin, counts, baseline.counts,
#     sort.group, batch, donor, timepoint, assay, k.type and t.type columns;
#     must contain exactly one assay, one k.type and one t.type value.
#   ref_bin: character vector of reference bins, or 'baseline'.
#   test_bin: character vector of test bins.
#   control_replicates: include `rep` (the first two `_`-separated fields of
#     each sample name) as a design covariate.
#   interaction: copy the reference columns once per test bin and add `bin`
#     to the design (only when more than one bin level is present).
#   group.control: with `interaction`, nest rep and bin within condition.
#   weight.bins: not implemented; a warning is raised when TRUE.
# Returns:
#   data.table of apeglm-shrunken test-vs-reference results, CAR.align first,
#   with assay, k.type and t.type columns appended.
run_deseq <- function(data.dt, ref_bin, test_bin,
                      control_replicates = TRUE,
                      interaction = FALSE, group.control = FALSE,
                      weight.bins = FALSE) {

  # identify inputs; each must be single-valued, otherwise the `==` filters
  # below would silently recycle
  assay_input <- as.character(unique(data.dt[, assay]))
  k_type <- as.character(unique(data.dt[, k.type]))
  t_type <- as.character(unique(data.dt[, t.type]))
  stopifnot(
    length(assay_input) == 1L,
    length(k_type) == 1L,
    length(t_type) == 1L)

  if (isTRUE(weight.bins)) {
    warning('weight.bins = TRUE is not implemented; bin weights are ignored.')
  }

  # reference counts: either baseline counts or the requested reference bins
  if (length(ref_bin) == 1L && ref_bin[[1]] == 'baseline') {
    # baseline.counts are taken from the bin 'D' rows only (presumably the
    # value is repeated across bins — confirm), one column per replicate
    ref.bin.dt <- dcast(
      data.dt[
        bin == 'D' & assay == assay_input &
          k.type == k_type & t.type == t_type,
        .(
          CAR.align,
          bin.sort.group = paste(
            batch, donor, timepoint, assay, t.type, 'base', sep = '_'),
          counts = baseline.counts)],
      CAR.align ~ bin.sort.group, value.var = 'counts')
  } else {
    ref.bin.dt <- cast_bin_counts(
      data.dt, ref_bin, assay_input, k_type, t_type)
  }

  # with interaction, copy the reference columns once per test bin so every
  # test bin gets its own reference samples (suffixed `_<test bin>`)
  if (isTRUE(interaction)) {
    ref.cols <- setdiff(names(ref.bin.dt), 'CAR.align')
    ref.copies <- lapply(test_bin, function(tb) {
      setnames(ref.bin.dt[, ref.cols, with = FALSE],
               paste(ref.cols, tb, sep = '_'))
    })
    ref.bin.dt <- do.call(
      cbind, c(list(ref.bin.dt[, 'CAR.align', with = FALSE]), ref.copies))
  }

  test.bin.dt <- cast_bin_counts(
    data.dt, test_bin, assay_input, k_type, t_type)

  stopifnot(nrow(ref.bin.dt) == nrow(unique(ref.bin.dt)))
  stopifnot(nrow(test.bin.dt) == nrow(unique(test.bin.dt)))

  # count table: CARs present in both tables, missing samples count as 0.
  # as.data.frame() keeps the sample names verbatim (data.frame() would
  # mangle them and break the match with the colData row names).
  cts.dt <- merge(ref.bin.dt, test.bin.dt, by = 'CAR.align')
  sample.names <- setdiff(names(cts.dt), 'CAR.align')
  cts <- as.data.frame(cts.dt[, sample.names, with = FALSE])
  rownames(cts) <- cts.dt$CAR.align
  cts[is.na(cts)] <- 0

  n.ref <- ncol(ref.bin.dt) - 1L
  n.test <- ncol(test.bin.dt) - 1L
  stopifnot(length(sample.names) == n.ref + n.test)

  # sample metadata parsed from the `_`-separated sample names; assumes the
  # bin sits in the 7th field — TODO confirm for sort.group-based names.
  # Columns are factors up front because DESeq2 design variables must be
  # factors (bin may be added to the design below).
  name.parts <- strsplit(sample.names, '_', fixed = TRUE)
  sample.bins <- vapply(name.parts, function(p) p[7], character(1))
  sample.reps <- vapply(
    name.parts, function(p) paste(p[1], p[2], sep = '_'), character(1))
  coldata <- data.frame(
    condition = factor(
      rep(c('reference', 'test'), c(n.ref, n.test)),
      levels = c('reference', 'test')),
    rep = factor(sample.reps),
    bin = factor(sample.bins),
    row.names = sample.names)

  print(coldata)

  # design: replicate-controlled by default; interaction adds bin terms when
  # there is more than one bin level to contrast
  n_uniq_bins <- length(unique(sample.bins))
  if (isTRUE(control_replicates)) {
    design.formula <- ~ condition + rep
  } else {
    design.formula <- ~ condition
  }
  if (isTRUE(interaction)) {
    if (n_uniq_bins > 1) {
      if (isTRUE(group.control)) {
        design.formula <- ~ condition + condition:rep + condition:bin
      } else {
        design.formula <- ~ condition + rep + bin
      }
    } else {
      warning('Cannot use bin contrast, only one bin level.')
    }
  }

  # 'reference' is the first condition level, so it is the DESeq2 reference
  dds <- DESeqDataSetFromMatrix(countData = cts,
                                colData = coldata,
                                design = design.formula)
  print(design(dds))

  # pre-filtering: drop CARs with fewer than 10 reads across all samples
  keep <- rowSums(counts(dds)) >= 10
  dds <- dds[keep, ]

  dds <- DESeq(dds)

  # shrink log fold changes of test vs reference
  resLFC <- lfcShrink(dds, coef = 'condition_test_vs_reference',
                      type = 'apeglm')

  # convert to data.table with CAR.align as the first column
  results.dt <- as.data.table(as.data.frame(resLFC))
  results.dt[, CAR.align := rownames(resLFC)]
  setcolorder(results.dt, 'CAR.align')
  results.dt[, `:=`(assay = assay_input, k.type = k_type, t.type = t_type)]

  results.dt
}

# Comparison specs: element i of `ref`, `test` and `interaction` describes
# one run_deseq() call (ref_bin, test_bin, interaction) in the loop below.
test_ref_sets <- list(
  ref = c(
    # A/B/AB vs D, then vs C, then vs C+D
    rep(list('D'), 3),
    rep(list('C'), 3),
    rep(list(c('C', 'D')), 3),
    # A/B/ABCD vs baseline
    rep(list('baseline'), 3)
  ),
  test = c(
    rep(list('A', 'B', c('A', 'B')), 3),
    list('A', 'B', c('A', 'B', 'C', 'D'))
  ),
  interaction = as.list(c(rep(FALSE, 9), rep(TRUE, 3)))
)

# Run DESeq2 for every ref/test spec in test_ref_sets and every `group` of
# read.counts (excluding the 'post-cytof' batch and rows without k.type),
# then cache the combined results. When rerun_deseq is FALSE the expensive
# loop is skipped and the cached results are loaded instead.
if (rerun_deseq) {
  per.set.results <- vector('list', length(test_ref_sets$ref))

  for (set_i in seq_along(test_ref_sets$ref)) {
    ref_set <- test_ref_sets$ref[[set_i]]
    test_set <- test_ref_sets$test[[set_i]]

    ref_str <- paste0(ref_set, collapse = '')
    test_str <- paste0(test_set, collapse = '')

    inter <- test_ref_sets$interaction[[set_i]]

    deseq.results.dt <- read.counts[batch != 'post-cytof' & !is.na(k.type),
      {
        group_label <- as.character(.BY[[1]])
        message(paste(c(ref_str, test_str, inter, group_label),
                      collapse = ' - '))
        tryCatch(
          run_deseq(
            data.dt = .SD,
            ref_bin = ref_set,
            test_bin = test_set,
            interaction = inter),
          # keep going on failure, but report why instead of silently
          # returning no rows for this group
          error = function(e) {
            warning(sprintf('run_deseq failed (%s vs %s, group %s): %s',
                            test_str, ref_str, group_label,
                            conditionMessage(e)),
                    call. = FALSE)
            data.table()
          }
        )
      },
      by = .(group)]

    if (nrow(deseq.results.dt) > 0) {
      deseq.results.dt[,
        `:=`(
          ref_set = ref_str,
          test_set = test_str,
          inter = inter)]
    }

    per.set.results[[set_i]] <- deseq.results.dt
  }

  # bind once instead of growing a data.table inside the loop
  all.deseq.results.dt <- rbindlist(per.set.results,
                                    use.names = TRUE, fill = TRUE)

  save(list = c('all.deseq.results.dt'),
       file = file.path(data.output.dir, 'pooled_deseq2_data.Rdata'))
} else {
  load(file = file.path(data.output.dir, 'pooled_deseq2_data.Rdata'))
}


# Display columns for the plots: -log10(padj) values above 10 become Inf and
# log2 fold changes beyond +/-5 become signed Inf; NA values stay NA.
all.deseq.results.dt[, `:=`(
  padj.disp = {
    neg.log.padj <- -log10(padj)
    replace(neg.log.padj, which(neg.log.padj > 10), Inf)
  },
  lfc.disp = {
    lfc <- log2FoldChange
    too.big <- which(abs(lfc) > 5)
    replace(lfc, too.big, sign(lfc[too.big]) * Inf)
  }
)]

# Receptor groups highlighted in the plots; every other CAR.align is 'other'.
control_domains <- c('41BB', 'CD28')
chosen_domains <- c('BAFF-R', 'CD40', 'TACI', 'TNR8')
neg_domain <- c('KLRG1')

# Precedence when a receptor is in several groups: neg > chosen > control.
all.deseq.results.dt[, CAR.type := factor(
  ifelse(CAR.align %in% neg_domain, 'neg',
    ifelse(CAR.align %in% chosen_domains, 'chosen',
      ifelse(CAR.align %in% control_domains, 'control', 'other'))),
  levels = c('other', 'control', 'chosen', 'neg'))]

# Legend label, colour and point size per CAR.type level, shared by all the
# plots below. The vectors are named by level so ggplot2 matches them by
# name. Unnamed (positional) vectors are matched against only the levels
# present in the plotted subset, so they shift onto the wrong level when one
# is absent.
car_type_labels <- c(
  other = 'Other Receptors',
  control = 'CD28/41BB',
  chosen = 'New Receptors',
  neg = 'Negative')

car_type_colors <- c(
  other = 'grey50',
  setNames(RColorBrewer::brewer.pal(5, 'Paired')[c(2, 4, 5)],
           c('control', 'chosen', 'neg')))

# Matching colour and size scales for CAR.type. `sizes` gives the size for
# the other/control/chosen/neg levels, in that order.
car_type_scales <- function(sizes = c(1, 3, 3, 3)) {
  list(
    scale_color_manual('',
      labels = car_type_labels,
      values = car_type_colors),
    scale_size_manual('',
      labels = car_type_labels,
      values = setNames(sizes, names(car_type_labels))))
}

# Volcano plot of display-capped log2 fold change vs -log10(adjusted p),
# one row of panels per comparison.
make_volcanoes <- function(data.dt) {
  ggplot(data.dt, aes(
    x = lfc.disp, y = padj.disp,
    color = CAR.type,
    label = CAR.align,
    size = CAR.type)) +
    geom_point() +
    facet_grid(test_set + ref_set ~ t.type + assay + k.type) +
    car_type_scales(c(1, 3, 3, 3)) +
    labs(x = 'Log2 FC', y = '-log10(adjusted P-value)',
         title = 'Assay Volcano Plots')
}

# Display-capped log2 fold change of each receptor across assays, one line
# per CAR.align.
make_timeseries <- function(data.dt) {
  ggplot(data.dt, aes(
    y = lfc.disp, x = assay,
    color = CAR.type,
    group = CAR.align,
    label = CAR.align,
    size = CAR.type)) +
    geom_point() +
    geom_line() +
    facet_grid(t.type ~ test_set + ref_set) +
    car_type_scales(c(0.5, 1, 1, 1)) +
    labs(y = 'Log2 FC', x = 'Assay', title = 'Log fold change across assays')
}

# Raw log2 fold change in CD4 (x) vs CD8 (y); expects both t.type levels.
make_cd4_cd8 <- function(data.dt) {
  ggplot(
    dcast(data.dt,
      CAR.align + assay + k.type + ref_set + test_set + inter + CAR.type ~
        t.type,
      value.var = c("log2FoldChange", "padj")),
    aes(y = log2FoldChange_CD8, x = log2FoldChange_CD4,
        color = CAR.type,
        label = CAR.align,
        size = CAR.type)) +
    geom_point() +
    facet_grid(test_set + ref_set ~ assay + k.type) +
    car_type_scales(c(1, 3, 3, 3)) +
    labs(x = 'CD4', y = 'CD8', title = 'Log fold change, CD4 vs CD8')
}

# Display-capped log2 fold change for k.type 'neg' (x, CD19-) vs 'pos'
# (y, CD19+).
make_pos_neg <- function(data.dt) {
  ggplot(
    dcast(data.dt,
      CAR.align + assay + t.type + ref_set + test_set + inter + CAR.type ~
        k.type,
      value.var = c("lfc.disp", "padj")),
    aes(y = lfc.disp_pos, x = lfc.disp_neg,
        color = CAR.type,
        label = CAR.align,
        size = CAR.type)) +
    geom_point() +
    facet_grid(test_set + ref_set ~ t.type + assay) +
    car_type_scales(c(1, 3, 3, 3)) +
    labs(x = 'CD19-', y = 'CD19+', title = 'Log fold change, CD19+ vs CD19-')
}

# ---- Bin vs baseline measurements ----

# Bin vs baseline comparisons, k.type 'pos' only. The CD19+ vs CD19- view of
# the baseline comparisons is drawn once, in the CD19 section below.
ggplotly(make_volcanoes(all.deseq.results.dt[
  k.type == 'pos' & ref_set == 'baseline']),
  tooltip = "label", session = 'knitr')
ggplotly(make_cd4_cd8(all.deseq.results.dt[
  k.type == 'pos' & ref_set == 'baseline']),
  tooltip = "label", session = 'knitr')
ggplotly(make_timeseries(all.deseq.results.dt[
  k.type == 'pos' & ref_set == 'baseline']),
  tooltip = "label", session = 'knitr')

# ---- Inter-bin measurements ----

# Inter-bin comparisons (reference is a sort bin, not baseline), k.type
# "pos" only.
ggplotly(
  make_volcanoes(
    all.deseq.results.dt[k.type == "pos" & ref_set != "baseline"]),
  tooltip = "label",
  session = "knitr")
ggplotly(
  make_cd4_cd8(
    all.deseq.results.dt[k.type == "pos" & ref_set != "baseline"]),
  tooltip = "label",
  session = "knitr")
ggplotly(
  make_timeseries(
    all.deseq.results.dt[k.type == "pos" & ref_set != "baseline"]),
  tooltip = "label",
  session = "knitr")

# ---- CD19+ vs CD19- ----

# k.type "pos" vs "neg" fold changes: first the baseline comparisons, then
# the inter-bin comparisons.
ggplotly(
  make_pos_neg(all.deseq.results.dt[ref_set == "baseline"]),
  tooltip = "label",
  session = "knitr")
ggplotly(
  make_pos_neg(all.deseq.results.dt[ref_set != "baseline"]),
  tooltip = "label",
  session = "knitr")